library(R.matlab)
library(ggplot2)
library(ggforce)
library(tidyverse)
library(mvtnorm)
library(grid)
library(svglite)
library(data.table)
library(showtext)
library(cowplot)
library(Cairo)


# ---- Setup: font, working directory, and input data ----

# Register the Palatino font file under the family name "Palatino" so the
# plotting code can request it by name.
# NOTE(review): this is a macOS-style system path; the script will fail at
# font_add() on machines where the file does not exist.
pth_reg <- "/System/Library/Fonts/Palatino.ttc"
font_add(family = "Palatino", regular = pth_reg)


# All relative paths below are resolved against this directory.
# NOTE(review): hard-coded, user-specific absolute path; it must be edited
# before the script can run anywhere else.
setwd("/Users/jbhoward/Documents/Personal Research/JAERE Extension/FINAL SIMULATION")

# Directory (relative to the working directory) the figure is written to.
out_pth <- file.path("figures", "output")

# Load the simulation output. `mat` is a global list that get_probs() reads
# directly (result sets by name, plus the m.* / v.* distribution parameters).
fpth <- file.path("simulation", "counterfactual_results.mat")
mat <- readMat(fpth)

#' Share of non-exiting grid points assigned to each location.
#'
#' Reads one counterfactual result set from the global list `mat` (loaded with
#' `readMat()`), labels every (phi, del) grid point as inland, coastal or
#' "Exit", and returns the share of non-exit grid points in each location.
#'
#' @param key Name of the result set inside `mat` (e.g. "endog.results").
#' @param tar_id Index along the third dimension of that result set; the
#'   driver loop passes the trade-cost scenario index here.
#' @return A data frame (tibble) with one row per location and the columns
#'   `firm_loc` ("prob.inland" / "prob.coastal"), `mass` (share of non-exit
#'   grid points in that location) and `ttl` (number of non-exit grid points).
get_probs <- function(key, tar_id) {
    all_results <- mat[[key]]
    if (is.null(all_results)) {
        stop("`mat` has no result set named '", key, "'", call. = FALSE)
    }
    objects <- all_results[, , tar_id]

    phi <- objects$phi
    del <- objects$del
    n_grid <- length(phi)

    # Grid in column-major order (phi varies fastest, del slowest) so that it
    # lines up with the flattened result arrays copied in below.
    df <- data.frame(
        phi = rep(phi, n_grid),
        del = rep(del, each = n_grid)
    )

    ############## GET GCHI DATA ##################
    # Bivariate log-normal density of (phi, del): a normal density in
    # (log phi, log del) divided by the Jacobian phi * del, evaluated for the
    # whole grid in one vectorised call and normalised to sum to one.
    log_mean <- c(mat$m.phi, mat$m.del)
    log_cov_offdiag <- mat$v.cor * mat$v.phi * mat$v.del
    log_cov <- matrix(
        c(mat$v.phi^2, log_cov_offdiag, log_cov_offdiag, mat$v.del^2),
        ncol = 2
    )
    dens <- dmvnorm(
        cbind(log(df$phi), log(df$del)),
        mean = log_mean, sigma = log_cov, log = FALSE, checkSymmetry = TRUE
    ) / (df$phi * df$del)
    df$g_chi <- dens / sum(dens)

    # Result fields copied into the grid table (flattened with c()).
    vars <- c("prob.coastal", "prob.inland",
        "prob.inland.domestic", "prob.inland.exporter",
        "prob.coastal.domestic", "prob.coastal.exporter",
        "opt.e.coastal", "opt.e.inland",
        "inland.X", "coastal.X", "d.inland", "d.coastal",
        "share.fa.inland", "share.fa.coastal",
        "dxratio.inland", "dxratio.coastal",
        "xdratio.inland", "xdratio.coastal",
        "fa.inland", "fa.coastal", "int.x.coastal", "int.x.inland")

    # Fail early with a clear message if a field the classification below
    # depends on is absent from the result set.
    required <- c("prob.inland", "prob.coastal",
        "prob.inland.domestic", "prob.inland.exporter",
        "prob.coastal.exporter")
    is_absent <- vapply(
        required, function(v) is.null(objects[[v]]), logical(1)
    )
    if (any(is_absent)) {
        stop(
            "Result set '", key, "' is missing field(s): ",
            paste(required[is_absent], collapse = ", "),
            call. = FALSE
        )
    }

    for (var in vars) {
        df[[var]] <- c(objects[[var]])
    }

    # ---- Location ----
    # A grid point is inland when its inland probability is at least the
    # grid-wide mean inland probability, coastal otherwise; it is an exit when
    # the larger of its two location probabilities is exactly zero.
    df$loc_prob_max <- pmax(df$prob.inland, df$prob.coastal, na.rm = TRUE)
    il_prob <- mean(df$prob.inland)
    df$firm_loc <- "prob.coastal"
    df$firm_loc[df$prob.inland >= il_prob] <- "prob.inland"
    df$firm_loc[df$loc_prob_max == 0] <- "Exit"

    is_inland <- df$firm_loc == "prob.inland"
    is_coastal <- df$firm_loc == "prob.coastal"

    # ---- Export status ----
    # Inland: whichever of the domestic / exporter probabilities is larger
    # (which.max keeps the first on ties, i.e. domestic).
    inland_vars <- c("prob.inland.domestic", "prob.inland.exporter")
    df$intl_inland <- inland_vars[
        apply(as.matrix(df[, inland_vars]), 1, which.max)
    ]

    # Coastal: exporter unless the exporter probability falls below its mean
    # over the coastal grid points.
    cutoff <- mean(df$prob.coastal.exporter[is_coastal])
    df$intl_coastal <- "prob.coastal.exporter"
    df$intl_coastal[df$prob.coastal.exporter < cutoff] <- "prob.coastal.domestic"

    df$firmtype <- "Exit"
    df$firmtype[is_inland] <- df$intl_inland[is_inland]
    df$firmtype[is_coastal] <- df$intl_coastal[is_coastal]

    # Human-readable label for every firm type.
    type_labels <- c(
        "Exit" = "Exit",
        "prob.inland.domestic" = "Periphery-Domestic",
        "prob.inland.exporter" = "Periphery-International",
        "prob.coastal.domestic" = "Core-Domestic",
        "prob.coastal.exporter" = "Core-International"
    )
    df$type <- unname(type_labels[df$firmtype])

    # ---- Location shares among non-exit grid points ----
    # NOTE(review): `mass` counts grid points with equal weight; the density
    # `g_chi` computed above is not used here. Confirm whether the shares
    # should be weighted by g_chi instead.
    # NOTE(review): only the Exit / non-Exit distinction of `type` affects the
    # result; the export-status labels are not used by the summary.
    df %>%
        filter(type != "Exit") %>%
        group_by(firm_loc) %>%
        summarize(mass = n()) %>%
        ungroup() %>%
        mutate(ttl = sum(mass)) %>%
        mutate(mass = mass / ttl)
}

# Result sets to compare: display label -> name of the entry in `mat`.
SIMS <- list("Matching" = "endog.results", "Random" = "exog.results")
# Trade-cost scenarios: display label -> index passed to get_probs() as tar_id.
TRADE_COSTS <- c("Frictionless Trade" = 1, "15% Trade Cost" = 8)

# Collect one summary table per (simulation, trade cost) pair, each tagged
# with the display labels of the pair it belongs to.
out <- list()
counter <- 0

for (sim_label in names(SIMS)) {
    key <- SIMS[[sim_label]]
    for (cost_label in names(TRADE_COSTS)) {
        tar_id <- TRADE_COSTS[[cost_label]]
        counter <- counter + 1
        tmp <- get_probs(key, tar_id)
        tmp$`Simulation Key` <- sim_label
        tmp$tariff <- cost_label
        out[[counter]] <- tmp
    }
}

# Stack the per-scenario tables into a single data.table.
out <- rbindlist(out)


# ---- Bar chart: core share of firms by simulation, one panel per trade cost ----

# Only the coastal rows are plotted; the y axis labels them as the core share.
core_share <- out[firm_loc == "prob.coastal"]

# Styling applied on top of theme_minimal(): no legend, black Palatino text,
# and no x-axis tick labels (each bar carries its own text label instead).
fig_theme <- theme(
  legend.position = "none",
  text = element_text(size = 12, colour = "black", family = "Palatino"),
  axis.text = element_text(color = "black"),
  panel.spacing = unit(2, "lines"),
  legend.key.width = unit(1, "cm"),
  axis.text.x = element_blank()
)

p <- ggplot(
  data = core_share,
  mapping = aes(
    x = factor(`Simulation Key`),
    y = as.numeric(mass),
    fill = `Simulation Key`,
    linetype = `Simulation Key`
  )
) +
  # Layers: bars first, then the simulation name printed above each bar.
  geom_col(
    position = "dodge", width = 0.75, linewidth = 1, color = "darkgrey"
  ) +
  geom_text(
    aes(label = `Simulation Key`),
    vjust = -0.5, colour = "black", size = 3.25,
    position = position_dodge(width = 1),
    family = "Palatino", lineheight = 0.7
  ) +
  # Scales.
  scale_fill_brewer(palette = "Paired", direction = -1) +
  scale_linetype_manual(values = c("blank", "dashed")) +
  scale_y_continuous(breaks = c(0.2, 0.4, 0.6), limits = c(0, 0.7)) +
  # Panels, labels, and theme.
  facet_grid(~ tariff, switch = "both") +
  labs(fill = NULL, x = NULL, y = "Core Share of Firms", linetype = NULL) +
  theme_minimal() +
  fig_theme

# Draw the figure to an EPS file with showtext enabled, then switch showtext
# off again once the device is closed.
fn <- "fig09_cf_loc_prob.eps"
eps_path <- file.path(out_pth, fn)

showtext_auto()
cairo_ps(eps_path, width = 4, height = 3, fallback_resolution = 320)
print(p)
dev.off()
showtext_auto(FALSE)